/*
 * Copyright (C) 2016 Edward Raff
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.jstarcraft.ai.jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.calibration.BinaryScoreClassifier;
import com.jstarcraft.ai.jsat.clustering.kmeans.ElkanKernelKMeans;
import com.jstarcraft.ai.jsat.clustering.kmeans.KernelKMeans;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.distributions.kernels.RBFKernel;
import com.jstarcraft.ai.jsat.exceptions.UntrainedModelException;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameter;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.ListUtils;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;

/**
 * This is an implementation of the Divide-and-Conquer Support Vector Machine
 * (DC-SVM). It uses a combination of clustering and warm-starting to train
 * faster, as well as an early stopping strategy to provide a fast approximate
 * SVM solution. The final accuracy should often be at or near that of a normal
 * SVM, while being faster to train. <br>
 * <br>
 * The current implementation is based on {@link SVMnoBias}, meaning this code
 * does not have a bias term and it only works with normalized kernels. Any
 * non-normalized kernel will be normalized automatically. This is not a problem
 * for the common RBF kernel.<br>
 * <br>
 *
 * See:
 * <ul>
 * <li>Hsieh, C.-J., Si, S., &amp; Dhillon, I. S. (2014). <i>A Divide-and-Conquer
 * Solver for Kernel Support Vector Machines</i>. In Proceedings of the 31st
 * International Conference on Machine Learning. Beijing, China.</li>
 * </ul>
 *
 * @author Edward Raff
 */
public class DCSVM extends SupportVectorLearner implements Classifier, Parameterized, BinaryScoreClassifier {
    /** Soft margin parameter handed to every sub-problem solver. */
    private double C = 1;
    /** Carried along by the copy constructor; not otherwise read in this class. */
    private double tolerance = 1e-3;

    /** Clustering produced by the most recent level processed in {@link #train}. */
    private KernelKMeans clusters;
    /** Default number of points sampled for clustering at each level. */
    private int m = 2000;
    /** Level the training starts at. */
    private int l_max = 4;
    /** Level the training stops at. */
    private int l_early = 3;
    /** Branching factor: level {@code l} asks for {@code k^l} clusters. */
    private int k = 4;

    /** Maps a cluster id to the model trained on that cluster's points. */
    private Map<Integer, SVMnoBias> early_models;
    /** Cache budget passed to each sub-solver; non-positive disables caching. */
    private long cache_size = 0;

    /**
     * Creates a new DC-SVM for the given kernel. The cache budget is set to half
     * of the JVM's free memory at construction time.
     *
     * @param k the kernel to use
     */
    public DCSVM(KernelTrick k) {
        super(k, CacheMode.ROWS);
        this.cache_size = Runtime.getRuntime().freeMemory() / 2;
    }

    /**
     * Creates a new DC-SVM for the RBF kernel
     */
    public DCSVM() {
        this(new RBFKernel());
    }

    /**
     * Copy Constructor
     *
     * @param toCopy object to copy
     */
    public DCSVM(DCSVM toCopy) {
        super(toCopy);
        this.C = toCopy.C;
        this.tolerance = toCopy.tolerance;
        if (toCopy.clusters != null) {
            this.clusters = toCopy.clusters.clone();
        }
        this.cache_size = toCopy.cache_size;
        this.m = toCopy.m;
        this.l_early = toCopy.l_early;
        this.l_max = toCopy.l_max;
        this.k = toCopy.k;
        if (toCopy.early_models != null) {
            this.early_models = new ConcurrentHashMap<>();
            for (Map.Entry<Integer, SVMnoBias> x : toCopy.early_models.entrySet()) {
                this.early_models.put(x.getKey(), x.getValue().clone());
            }
        }
    }

    /**
     * The DC-SVM algorithm works by creating a hierarchy of levels, and iteratively
     * refining the solution from one level to the next. Level 0 corresponds to the
     * exact SVM solution, and higher levels are coarser approximations. This method
     * controls which level the training starts at.
     *
     * @param l_max which level to start the training at.
     * @throws IllegalArgumentException if {@code l_max} is negative
     */
    public void setStartLevel(int l_max) {
        if (l_max < 0) {
            throw new IllegalArgumentException("l_max must be a non-negative integer, not " + l_max);
        }
        this.l_max = l_max;
    }

    /**
     *
     * @return which level to start the training at.
     */
    public int getStartLevel() {
        return l_max;
    }

    /**
     * The DC-SVM algorithm works by creating a hierarchy of levels, and iteratively
     * refining the solution from one level to the next. Level 0 corresponds to the
     * exact SVM solution, and higher levels are coarser approximations. This method
     * controls which level the training stops at, with 0 being the latest it can
     * stop. The default stopping level is 3.
     *
     * @param l_early which level to stop the training at, and use for
     *                classification.
     * @throws IllegalArgumentException if {@code l_early} is negative
     */
    public void setEndLevel(int l_early) {
        if (l_early < 0) {
            throw new IllegalArgumentException("l_early must be a non-negative integer, not " + l_early);
        }
        this.l_early = l_early;
    }

    /**
     *
     * @return which level to stop the training at, and use for classification.
     */
    public int getEndLevel() {
        return l_early;
    }

    /**
     * At each level of the DC-SVM training, a clustering algorithm is used to
     * divide the dataset into sub-groups for independent training. Increasing the
     * number of points used for clustering improves model accuracy, but also
     * increases training time. The default value is 2000. This value may need to be
     * increased if using a higher starting level.
     *
     * @param m the number of data points to sample for each cluster size
     * @throws IllegalArgumentException if {@code m} is not positive
     */
    public void setClusterSampleSize(int m) {
        if (m <= 0) {
            throw new IllegalArgumentException("Cluster Sample Size must be a positive integer, not " + m);
        }
        this.m = m;
    }

    /**
     *
     * @return the number of data points sampled for clustering at each level
     */
    public int getClusterSampleSize() {
        return m;
    }

    /**
     * Classifies the data point by the sign of {@link #getScore(DataPoint)}: a
     * strictly positive score puts all probability on class 1, anything else on
     * class 0.
     *
     * @param data the data point to classify
     * @return a two-class result with probability 1 on the predicted class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);

        double sum = getScore(data);

        if (sum > 0) {
            cr.setProb(1, 1);
        } else {
            cr.setProb(0, 1);
        }

        return cr;
    }

    /**
     * Scores the data point with the model belonging to its closest cluster. When
     * only one model exists, the model stored under cluster id 0 is used without
     * consulting the clustering.
     *
     * @param dp the data point to score
     * @return the score produced by the selected sub-model
     * @throws UntrainedModelException if the classifier has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        if (vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }

        Vec x = dp.getNumericalValues();
        int c;
        if (early_models.size() > 1) {
            c = clusters.findClosestCluster(x, getKernel().getQueryInfo(x));
        } else {
            c = 0;
        }
        return early_models.get(c).getScore(dp);
    }

    /**
     * Trains the model level by level, from the start level down to the end
     * level. Each level clusters a sub-sample of the data, assigns every training
     * point to a cluster, and trains one {@link SVMnoBias} per cluster. Every
     * level after the first warm-starts from the alphas of the previous level.
     * If the end level is 0, one more model is trained on the whole data set,
     * warm-started from the alphas found so far.
     *
     * @param dataSet  the data set to train on
     * @param parallel whether clustering, assignment and sub-model training may
     *                 run in parallel
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        final int N = dataSet.size();
        vecs = dataSet.getDataVectors();
        early_models = new ConcurrentHashMap<>();
        setCacheMode(CacheMode.NONE);// Initiates the accel cache
        // initialize alphas array to all zero
        alphas = new double[N];// zero is default value

        // Used to keep track of which sub cluster each training datapoint belongs to
        final int[] group = new int[N];

        // Used to select subsamples of data points for clustering, and to map them
        // back to their original indices
        IntArrayList indicies = new IntArrayList();
        // for l = lmax, . . . , 1 do
        for (int l = l_max; l >= l_early; l--) {
            early_models.clear();
            // sub-sampled dataset to use for clustering
            ClassificationDataSet toCluster = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
            // Set number of clusters in the current level k_l = k^l
            int k_l = (int) Math.pow(k, l);

            // number of datapoints to use in clustering
            // M = m by default. Increase to M = 7 k_l if less than 7 points per cluster
            final int M = (N / k_l < 7) ? k_l * 7 : m;

            if (l == l_max) {
                // first level: every point is a candidate
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                // later levels: only points with a non-zero alpha are candidates
                indicies.clear();
                for (int i = 0; i < N; i++) {
                    if (alphas[i] != 0) {
                        indicies.add(i);
                    }
                }
            }
            Collections.shuffle(indicies);
            // Sample through the shuffled index list, so that position i of the
            // clustering result refers to original point indicies.getInt(i). The
            // partitioning step below relies on exactly that mapping.
            final int sampleSize = Math.min(M, indicies.size());
            for (int i = 0; i < sampleSize; i++) {
                final int orig = indicies.getInt(i);
                toCluster.addDataPoint(dataSet.getDataPoint(orig), dataSet.getDataPointCategory(orig));
            }

            // Run kernel kmeans on {xi1, . . . ,xim} to get cluster centers c1, . . . , ckl
            clusters = new ElkanKernelKMeans(getKernel());
            clusters.setMaximumIterations(100);
            k_l = Math.min(k_l, toCluster.size() / 2);// Few support vectors? Make clustering smaller then
            int[] sub_results;
            if (k_l <= 1) {
                // dont run cluster, we are doing final refinement step!
                sub_results = new int[N];// will be all 0, for 1 'cluster'
                indicies.clear();
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                sub_results = clusters.cluster(toCluster, k_l, parallel, (int[]) null);
            }

            // create partitioning
            // First, don't bother with distance computations for people we just clustered
            Arrays.fill(group, -1);
            IntOpenHashSet found_clusters = new IntOpenHashSet(k_l);
            for (int i = 0; i < sub_results.length; i++) {
                group[indicies.getInt(i)] = sub_results[i];
                found_clusters.add(sub_results[i]);
            }
            // find who everyone else belongs to
            ParallelUtils.run(parallel, N, (i) -> {
                if (group[i] >= 0) {
                    return;// you already got assigned above
                }

                DoubleList qi = null;
                if (accelCache != null) {
                    int multiplier = accelCache.size() / N;
                    qi = accelCache.subList(i * multiplier, i * multiplier + multiplier);
                }
                group[i] = clusters.findClosestCluster(vecs.get(i), qi);
            });
            // everyone has now been assigned to their closest cluster

            // build SVM model for each cluster
            for (int c : found_clusters) {
                ClassificationDataSet V_c = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
                DoubleArrayList V_alphas = new DoubleArrayList();
                IntArrayList orig_index = new IntArrayList();
                for (int i = 0; i < N; i++) {
                    if (group[i] != c) {
                        continue;// well get to you later
                    }
                    // else, create dataset
                    V_c.addDataPoint(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i));
                    V_alphas.add(Math.abs(alphas[i]));
                    orig_index.add(i);
                }

                SVMnoBias svm = newSubSolver(V_alphas.size());

                // Train model
                if (l == l_max) {
                    // first round, no warm start
                    svm.train(V_c, parallel);
                } else {
                    // warm start; toDoubleArray() yields an array of exactly size()
                    // entries, whereas the list's backing array may be longer
                    svm.train(V_c, V_alphas.toDoubleArray(), parallel);
                }
                early_models.put(c, svm);

                // Update larger set of alphas
                for (int i = 0; i < orig_index.size(); i++) {
                    this.alphas[orig_index.getInt(i)] = svm.alphas[i];
                }
            }
        }

        if (l_early == 0) {
            // fully solve the problem, warm-started from the alphas found above
            SVMnoBias svm = newSubSolver(dataSet.size());
            svm.train(dataSet, Arrays.copyOf(this.alphas, this.alphas.length), parallel);

            early_models.clear();
            early_models.put(0, svm);

            // Update all alphas
            for (int i = 0; i < N; i++) {
                this.alphas[i] = svm.alphas[i];
            }
        }
    }

    /**
     * Creates an untrained solver for a sub-problem, configured with this object's
     * kernel, soft margin parameter and cache budget.
     *
     * @param problemSize the number of data points the solver will be trained on
     * @return a configured, untrained solver
     */
    private SVMnoBias newSubSolver(int problemSize) {
        SVMnoBias svm = new SVMnoBias(getKernel());
        // without this the value given to setC would never reach any solver
        svm.setC(C);
        if (cache_size > 0) {
            svm.setCacheSize(problemSize, cache_size);
        } else {
            svm.setCacheMode(CacheMode.NONE);
        }
        return svm;
    }

    /**
     * Reports that weighted data is supported.
     * NOTE(review): this class does not read data weights itself; whether the
     * delegated {@link SVMnoBias} models honor them is not visible here - confirm.
     *
     * @return always {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public DCSVM clone() {
        return new DCSVM(this);
    }

    /**
     * Sets the complexity parameter of SVM. The larger the C value the harder the
     * margin SVM will attempt to find. Lower values of C allow for more
     * misclassification errors.
     *
     * @param C the soft margin parameter
     * @throws ArithmeticException if {@code C} is not positive
     */
    @Parameter.WarmParameter(prefLowToHigh = true)
    public void setC(double C) {
        if (C <= 0) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = C;
    }

    /**
     * Returns the soft margin complexity parameter of the SVM
     *
     * @return the complexity parameter of the SVM
     */
    public double getC() {
        return C;
    }
}
